import argparse
import pandas as pd
import json
from tqdm import tqdm
import torch
import numpy as np
import os
import random
from dataset.utils import compute_centroid
from attack.greedy_search import GreedySearchOptimizer
from utils import generate_experiment_filename

from models.language_models import Llama2_7b, Gemma7b, Vicuna13b, Mistral7b, PhiMini, Llama3_8b, Qwen7b, Zephyr7bR2D2, Mistral7B_RR, Llama3_8bRR
from models.substitutors import GPTSubstitutor, ModernBertSubstitutor


def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so
    every non-empty value (including ``"False"``) becomes ``True``. This
    converter interprets the usual textual spellings instead.

    Raises:
        argparse.ArgumentTypeError: if *value* is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    token = str(value).strip().lower()
    if token in ('true', 't', 'yes', 'y', '1'):
        return True
    if token in ('false', 'f', 'no', 'n', '0'):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean (true/false), got {value!r}")


def parse_args():
    """Build the command-line interface and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``model``, ``main_device``,
        ``judge``, ``evaluator``, ``evaluator_device``, ``subs``, ``layer``,
        ``substitutor``, ``test_file`` and ``num_samples``.
    """
    parser = argparse.ArgumentParser(description='Main Script for running LatentBreak')

    # Model parameters
    # NOTE(review): 'llama3-70b' and 'zephyr' are accepted here, but get_model()
    # in this file has no entry for them, so selecting them fails at model
    # construction time.
    parser.add_argument('--model', type=str, default='llama2-7b',
                        choices=['llama2-7b', 'gemma-7b', 'vicuna-13b', 'mistral7b', 'phi-mini', 'llama3-8b', 'llama3-70b', 'qwen7b', 'r2d2', "mistral7b-rr", "zephyr", "llama3-8b-rr"],
                        help='Model architecture to use')
    parser.add_argument('--main_device', type=str, default='cuda:2',
                        help='Device to run the model on')
    # Boolean flags take an explicit value (e.g. ``--judge false``); _str2bool
    # makes "false"/"0"/"no" actually disable the option.
    parser.add_argument('--judge', type=_str2bool, default=True,
                        help='Judge evaluates the intent (true/false)')
    parser.add_argument('--evaluator', type=_str2bool, default=False,
                        help='Enable the evaluator (true/false)')
    parser.add_argument('--evaluator_device', type=str, default='cuda:1',
                        help='Device to run the evaluator on')

    # Optimization parameters
    parser.add_argument('--subs', type=int, default=20,
                        help='Number of possible candidates')
    parser.add_argument('--layer', type=int, default=31,
                        help='Layer on which we want to optimize')

    # Substitutor parameters
    parser.add_argument('--substitutor', type=str, default='mbert',
                        choices=['gpt', 'mbert'],
                        help='Substitutor type')

    # Dataset parameters
    parser.add_argument('--test_file', type=str,
                        default='./dataset/raw/advbench_pair.csv',
                        help='Path to test file')
    parser.add_argument('--num_samples', type=int, default=2,
                        help='Number of samples to process')

    return parser.parse_args()

def get_model(model_name, device):
    """Instantiate the language-model wrapper registered under *model_name*.

    Args:
        model_name: One of the keys of the registry below (the CLI ``--model``
            value).
        device: Forwarded to the wrapper's constructor as ``device=``.

    Returns:
        A new instance of the selected wrapper class.

    Raises:
        KeyError: if *model_name* has no registered wrapper. The message lists
            the supported names, because the CLI accepts some names
            ('llama3-70b', 'zephyr') that are not registered here.
    """
    registry = {
        'llama2-7b': Llama2_7b,
        'gemma-7b': Gemma7b,
        'vicuna-13b': Vicuna13b,
        'mistral7b': Mistral7b,
        'phi-mini': PhiMini,
        'llama3-8b': Llama3_8b,
        'qwen7b': Qwen7b,
        'r2d2': Zephyr7bR2D2,
        'mistral7b-rr': Mistral7B_RR,
        'llama3-8b-rr': Llama3_8bRR,
    }
    try:
        model_cls = registry[model_name]
    except KeyError:
        supported = ", ".join(sorted(registry))
        # Same exception type as a plain dict lookup, but with an actionable message.
        raise KeyError(
            f"Unsupported model {model_name!r}; supported models: {supported}"
        ) from None
    return model_cls(device=device)

def save_result(results, filename):
    """Write *results* to *filename* as JSON, replacing any previous content.

    The payload is serialized before the file is opened. Opening in ``'w'``
    mode truncates the file, so serializing first guarantees that a value
    that cannot be encoded raises without destroying results saved by an
    earlier call.

    Args:
        results: JSON-serializable object to persist.
        filename: Destination path.

    Raises:
        TypeError / ValueError: propagated from ``json.dumps`` when *results*
            cannot be serialized; the existing file is left untouched.
        OSError: if the file cannot be opened or written.
    """
    payload = json.dumps(results)
    with open(filename, 'w') as f:
        f.write(payload)

def set_seeds(seed=0):
    """Seed every random number generator this script relies on.

    Seeds, in order: Python's ``random``, NumPy's global generator, PyTorch's
    generator, and PyTorch's CUDA generators (current device, then all
    devices). Finally exports ``PYTHONHASHSEED`` with the same value.

    Args:
        seed: Integer seed applied to all generators (default 0).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)
    # NOTE(review): Python reads PYTHONHASHSEED at interpreter start-up, so this
    # presumably only influences child processes — confirm if hash
    # determinism of the current process matters.
    os.environ["PYTHONHASHSEED"] = str(seed)

def _load_prompts(test_file):
    """Read the prompt column from *test_file* and name the dataset.

    The dataset flavour is picked by substring match on the path, checked in
    this order: ``'harmbench'``, ``'advbench'``, ``'harmful_behaviors_pair'``.

    Returns:
        ``(prompts, data_name)`` where *prompts* is a pandas Series and
        *data_name* is ``'harmbench'``, ``'advbench'`` or ``'pair'``.

    Raises:
        ValueError: if the path matches none of the known dataset names.
    """
    if 'harmbench' in test_file:
        df = pd.read_csv(test_file)
        # Keep only rows whose functional category mentions "standard".
        is_standard = df['FunctionalCategory'].str.contains('standard', case=False)
        return df[is_standard]['Behavior'], 'harmbench'
    if 'advbench' in test_file:
        return pd.read_csv(test_file)["goal"], 'advbench'
    if 'harmful_behaviors_pair' in test_file:
        return pd.read_csv(test_file)['goal'], 'pair'
    raise ValueError(
        f"Cannot infer dataset type from test file {test_file!r}: the path must "
        "contain 'harmbench', 'advbench' or 'harmful_behaviors_pair'")


def main():
    """Run the greedy-search attack on the first ``--num_samples`` prompts.

    Parses the CLI, loads the prompts and the target model, computes the
    centroid passed to the optimizer as ``target_embedding``, then optimizes
    each prompt in turn. The accumulated results are rewritten to the results
    file after every prompt.

    Raises:
        ValueError: if ``--test_file`` does not name a known dataset.
    """
    args = parse_args()

    # Load the dataset first: it is cheap, and an unusable --test_file should
    # fail before the model is loaded.
    prompts, data_name = _load_prompts(args.test_file)

    # Initialize model
    model = get_model(args.model, args.main_device)

    devices = {"main_device": args.main_device,
               "evaluator_device": args.evaluator_device}

    # Stored representations for this model, converted to a NumPy array.
    # NOTE(review): their exact layout is defined by compute_centroid — confirm there.
    HL_x = torch.load(f"./dataset/representations/{args.model}/HLx_train.pt", weights_only=True).numpy()
    layer = args.layer
    # Compute centroid
    c0 = compute_centroid(HL_x, layer)

    # Generate the name of the results file for this experiment configuration
    results_file = generate_experiment_filename(
        model_name=args.model, layer=layer,
        optimization_method='greedy', substitutor=args.substitutor, num_substitutions=args.subs, judge=args.judge, data_name=data_name)

    results = {}

    substitutors = {
        "mbert": ModernBertSubstitutor,
        "gpt": GPTSubstitutor,
    }

    # The substitutor is handed over as a class, not an instance.
    attacker = GreedySearchOptimizer(
            model=model,
            layer=layer,
            target_embedding=c0,
            num_substitutions=args.subs,
            substitutor=substitutors[args.substitutor],
            judge=args.judge,
            device=devices,
            evaluator=args.evaluator
        )

    # Positional slice; the dataset may hold fewer prompts than requested, so
    # report progress against the number actually selected.
    selected = prompts.iloc[:args.num_samples]
    total = len(selected)

    # Process prompts
    for j, prompt in enumerate(tqdm(selected, desc="Optimizing prompt", unit="iteration")):

        initial_prompt = prompt
        print(f"Processing prompt {j+1}/{total}")
        print("Initial prompt: ", initial_prompt)

        optimized_prompts, response, evaluation = attacker.optimize(prompt)
        results[str(j)] = {"optimization": optimized_prompts, "response": response, "is_jailbreak_harmbench": evaluation}

        # Save results after each optimization
        save_result(results, results_file)
        print(f"Saved results for prompt {j+1}")

if __name__ == "__main__":
    # Script entry point: seed the random number generators (set_seeds uses
    # its default seed of 0) before main() loads the model and runs the attack.
    set_seeds()
    main()